In [1]:
from IPython.core.display import HTML,Image
# Render a "Toggle Code" button that hides/shows the input cells of the
# classic Notebook UI (the jQuery selector targets `div.input`).
HTML(
    "<script> code_show=true;  function code_toggle() {  if (code_show){  "
    "$('div.input').hide();  } else {  $('div.input').show();  }  "
    "code_show = !code_show }  $( document ).ready(code_toggle); </script> "
    "<form action='javascript:code_toggle()'><input type='submit' value='Toggle Code'></form>"
)
Out[1]:
In [2]:
import gc, argparse, sys, os, errno
%pylab inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
#sns.set()
#sns.set_style('whitegrid')
import h5py
from PIL import Image
import os
from tqdm import tqdm_notebook as tqdm
import scipy
import sklearn
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings('ignore')
from scipy.io import loadmat
import IPython.display as ipd
import IPython
import librosa.display
import librosa
from pystoi import stoi
Populating the interactive namespace from numpy and matplotlib
In [3]:
# Word labels for the test samples, read from the TSV in the working directory.
# The listing is sorted so the choice is deterministic if several TSVs exist
# (os.listdir order is arbitrary); fail with a clear message if there is none.
tsv_files = sorted(name for name in os.listdir('.') if name.endswith('tsv'))
if not tsv_files:
    raise FileNotFoundError("no .tsv word list found in the working directory")
select_word = np.loadtxt(tsv_files[0], dtype='str')

Spectrogram

  • spec_gt (ground truth): upper half of each panel
  • spec_pred (prediction): lower half of each panel
In [4]:
def MSE_pcc(A, B, ax=None):
    """Compare a prediction ``A`` against a reference ``B``.

    Returns ``(mse, pcc)``:
      * mse -- mean squared error divided by the variance of ``B``
        (a normalised, reference-scale-free MSE);
      * pcc -- Pearson correlation coefficient of the flattened arrays.

    ``ax`` is accepted for call compatibility but is not used.
    If ``B`` is constant, ``B.var()`` is 0 and the returned mse is inf/nan.
    """
    reference_var = B.var()
    scaled_sq_err = (A - B) ** 2 / reference_var
    nmse = np.mean(scaled_sq_err)
    r = pearsonr(A.ravel(), B.ravel())[0]
    return nmse, r
def analyze(predict, GT_STFT_test_spkr):
    """Score every sample of ``predict`` against ``GT_STFT_test_spkr`` and plot.

    Sample ``k`` is scored with ``MSE_pcc(predict[k], GT_STFT_test_spkr[k])``.
    Draws a 1x2 figure: a histogram of the normalised MSE (left, 25 bins, blue)
    and of the PCC (right, 50 bins, green), each titled "mean(std)" with both
    statistics rounded to 3 decimals.

    Returns ``(mse, pcc)``: two float arrays of length ``predict.shape[0]``.
    """
    n_samples = predict.shape[0]
    pcc = np.zeros([n_samples])
    mse = np.zeros([n_samples])
    for k in range(n_samples):
        mse[k], pcc[k] = MSE_pcc(predict[k], GT_STFT_test_spkr[k])

    fig, axes = plt.subplots(1, 2, figsize=(16, 4))
    panels = (
        (axes[0], mse, 25, 'b', 'MSE: %g(%g)'),
        (axes[1], pcc, 50, 'g', 'PCC: %g(%g)'),
    )
    for panel, values, n_bins, colour, title_fmt in panels:
        panel.hist(values, bins=n_bins, color=colour)
        panel.set_title(title_fmt % (np.round(values.mean(), 3),
                                     np.round(values.std(), 3)))
    return mse, pcc
In [5]:
spec_gt = loadmat('spectrogram_GT.mat')['GT_STFT_test_spkr']
spec_pred = loadmat('spectrogram_prediction.mat')['pred_STFT_test']
# Swap axes 1 and 2 of each array and stack GT before prediction along axis 1,
# so each 2-D image spec_concat[k] shows GT in its upper half and prediction
# in its lower half (assumes 3-D arrays with samples on axis 0 — TODO confirm).
# Uses the explicit `np` alias rather than the bare `numpy` name, which only
# existed because `%pylab` injected it into the namespace.
spec_concat = np.concatenate(
    (np.swapaxes(spec_gt, 2, 1), np.swapaxes(spec_pred, 2, 1)),
    axis=1,
)
In [6]:
# 18 x 10 grid of GT-over-prediction spectrograms, titled with the word label.
row_nums = 18
col_nums = 10
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums * 2, row_nums * 1.5))
# `plt.cm` instead of the bare `cm` that only exists via `%pylab`.
cmap = plt.cm.coolwarm
# atleast_1d so a single-word TSV (0-d array from loadtxt) still supports len().
titles = np.atleast_1d(select_word)
for i in range(row_nums):
    for j in range(col_nums):
        idx = i * col_nums + j
        if idx >= len(spec_concat):
            # Fewer samples than grid cells: leave the remaining panels blank.
            ax[i, j].axis('off')
            continue
        ax[i, j].imshow(spec_concat[idx], cmap=cmap)
        # Only label panels that have a matching word. This replaces a bare
        # `except: pass`, which would also have hidden unrelated errors.
        if idx < len(titles):
            ax[i, j].set_title(titles[idx])
fig.tight_layout()

Waveform

In [7]:
# Load the ground-truth, reconstructed and merged waveforms resampled to
# 16 kHz; librosa.load returns (audio, sample_rate) and only the audio is kept.
wave_gt, wave_pred, wave_merge = (
    librosa.load(wav_path, sr=16000)[0]
    for wav_path in ('gt.wav', 'pred.wav', 'merge.wav')
)

ground truth

In [8]:
ipd.display(ipd.Audio(data=wave_gt, rate=16000))

reconstructed audio

In [9]:
ipd.display(ipd.Audio(data=wave_pred, rate=16000))

merged audio

In [10]:
ipd.display(ipd.Audio(data=wave_merge, rate=16000))

selected audio: the 10 samples whose predicted spectrogram has the highest PCC with the ground truth

In [11]:
# Audio window length per sample, used to slice the waveforms below:
# 16384 samples, i.e. 1.024 s at the 16 kHz rate used when loading.
interval = 16384
samples = spec_pred.shape[0]
# Same per-sample scores that `analyze` computes, without the plotting.
mse = np.zeros([samples])
pcc = np.zeros([samples])
for k in range(samples):
    mse[k], pcc[k] = MSE_pcc(spec_pred[k], spec_gt[k])
In [12]:
# Play the 10 samples with the highest spectrogram PCC, in descending order.
# Assumes merge.wav is laid out as consecutive chunks of 2 * interval samples,
# one per test sample — TODO confirm against the export script.
display(ipd.Audio(
    wave_merge.reshape(-1, 2 * interval)[np.argsort(-pcc)[:10]].ravel(),
    rate=16000,
))

ground-truth spectrogram inversion (audio reconstructed from the ground-truth spectrogram)

In [13]:
# Audio obtained by inverting the ground-truth spectrogram, at 16 kHz.
wave_gt_spec = librosa.load(path='on_gt_spec.wav', sr=16000)[0]
ipd.display(ipd.Audio(data=wave_gt_spec, rate=16000))
In [14]:
# For each of 18 x 10 samples, plot the GT waveform window (even rows) above
# the predicted waveform window (odd rows), titled with the word label.
row_nums = 18
col_nums = 10
fig, ax = plt.subplots(row_nums * 2, col_nums, figsize=(col_nums * 2, row_nums * 1.5))
# atleast_1d so a single-word TSV (0-d array from loadtxt) still supports len().
titles = np.atleast_1d(select_word)
for i in range(row_nums):
    for j in range(col_nums):
        idx = i * col_nums + j
        # The idx-th window of `interval` audio samples.
        seg = slice(idx * interval, (idx + 1) * interval)
        gt_ax, pred_ax = ax[i * 2, j], ax[i * 2 + 1, j]
        # Explicit bounds check instead of a bare `except: pass`.
        if idx < len(titles):
            gt_ax.set_title(titles[idx])
        gt_ax.plot(wave_gt[seg])
        pred_ax.plot(wave_pred[seg])
        gt_ax.axis('off')
        pred_ax.axis('off')
fig.tight_layout()

Visualization

drawing

Attention

This figure visualizes how the attention mask embedded in the encoder evolves across time.

In [15]:
# Keep index 0 of the last axis of the 5-D `ams_test` array (meaning of that
# axis not visible here — TODO confirm), then average over the first axis.
attention = (
    loadmat('attention_mask.mat')
    ['ams_test'][:, :, :, :, 0]
)
average_mask = np.mean(attention, axis=0)
average_mask.shape
Out[15]:
(36, 15, 15)
In [16]:
# 3 x 12 grid: one heatmap per entry of average_mask's first axis, with a
# shared colour scale and a single colorbar on the right.
row_nums = 3
col_nums = 12
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums * 3, row_nums * 3))
# Colour limits are the same for every panel, so compute them once:
# vmin is the mean of the non-zero entries (weaker values saturate to the
# bottom of the colormap); vmax is the global maximum.
# `np.mean` / `plt.cm` replace the `%pylab`-injected `mean` / `cm`.
vmin = np.mean(average_mask[average_mask != 0])
vmax = np.max(average_mask)
for i in range(row_nums):
    for j in range(col_nums):
        im = ax[i, j].imshow(average_mask[i * col_nums + j], cmap=plt.cm.hot,
                             vmin=vmin, vmax=vmax)
        ax[i, j].axis('off')

fig.subplots_adjust(right=0.84)
cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
Out[16]:
<matplotlib.colorbar.Colorbar at 0x7ff5b52b84d0>

ECoG

This figure visualizes how the ECoG signal evolves across time.

In [17]:
# Reshape the ECoG test inputs to (sample, 176, 15, 15) electrode grids and
# trim 16 steps off each end of axis 1 (presumably time — TODO confirm).
# -1 lets the sample count follow the file instead of a hard-coded 180.
ecog = loadmat('ecog.mat')['GT_STFT_test_ecog'][0, :, :, :].reshape(-1, 176, 15, 15)[:, 16:-16, :, :]

# Collapse the 144 trimmed steps into 36 windows of 4: max within each
# window, then mean over samples -> one 15x15 map per window.
n_bins, bin_len = 36, 4
ecog_ = np.zeros([n_bins, 15, 15])
for i in range(n_bins):
    ecog_[i] = np.mean(np.max(ecog[:, i * bin_len:(i + 1) * bin_len, :, :], 1), axis=0)

row_nums = 3
col_nums = 12
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums * 3, row_nums * 3))
# Shared, loop-invariant colour limits (previously recomputed per panel):
# vmin is the mean of the non-zero entries, vmax the global maximum.
# `np.mean` / `plt.cm` replace the `%pylab`-injected `mean` / `cm`.
vmin = np.mean(ecog_[ecog_ != 0])
vmax = np.max(ecog_)
for i in range(row_nums):
    for j in range(col_nums):
        im = ax[i, j].imshow(ecog_[i * col_nums + j], cmap=plt.cm.hot,
                             vmin=vmin, vmax=vmax)
        ax[i, j].axis('off')

fig.subplots_adjust(right=0.84)
cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
Out[17]:
<matplotlib.colorbar.Colorbar at 0x7ff5ba4f2250>

Gradient

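This figure visualizes the gradient of the decoding loss with respect to the input ECoG signal, back-propagated through the model. This gradient reflects the contribution of each electrode to the decoded speech at each time step. Negative gradient values are clipped to zero, and each panel shows the per-window maximum averaged over test samples.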

In [18]:
# Reshape the saved input gradients to (sample, 176, 15, 15) electrode grids
# and trim 16 steps off each end of axis 1, matching the ECoG cell.
# A single C-order reshape is equivalent to the former two-step
# (-1, 176, 225) -> (-1, 176, 15, 15) reshape.
gradient = loadmat('gradient.mat')['grad_loss2inp_test'].reshape(-1, 176, 15, 15)[:, 16:-16, :, :]
# Keep only positive gradients (negatives -> 0), as a new array rather than
# mutating the loaded data in place. The former follow-up np.abs() was a
# no-op after this clamp and has been dropped.
# NOTE(review): if absolute-magnitude saliency was intended, drop the clamp
# and use np.abs(gradient) instead.
gradient = np.maximum(gradient, 0)

# 36 windows of 4 steps: max within each window, then mean over samples.
n_bins, bin_len = 36, 4
gradient_ = np.zeros([n_bins, 15, 15])
for i in range(n_bins):
    gradient_[i] = np.mean(np.max(gradient[:, i * bin_len:(i + 1) * bin_len, :, :], 1), axis=0)

row_nums = 3
col_nums = 12
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums * 3, row_nums * 3))
# Shared, loop-invariant colour limits (previously recomputed per panel):
# vmin is the mean of the non-zero entries, vmax the global maximum.
# `np.mean` / `plt.cm` replace the `%pylab`-injected `mean` / `cm`.
vmin = np.mean(gradient_[gradient_ != 0])
vmax = np.max(gradient_)
for i in range(row_nums):
    for j in range(col_nums):
        im = ax[i, j].imshow(gradient_[i * col_nums + j], cmap=plt.cm.hot,
                             vmin=vmin, vmax=vmax)
        ax[i, j].axis('off')

fig.subplots_adjust(right=0.84)
cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
Out[18]:
<matplotlib.colorbar.Colorbar at 0x7ff5bf685050>

Metrics

PCC&MSE

In [19]:
# Reload both spectrogram sets from disk, score them with `analyze`
# (which also plots the histograms), and persist the per-sample scores.
spec_gt, spec_pred = (
    loadmat(mat_path)[mat_key]
    for mat_path, mat_key in (('spectrogram_GT.mat', 'GT_STFT_test_spkr'),
                              ('spectrogram_prediction.mat', 'pred_STFT_test'))
)
mse, pcc = analyze(spec_pred, spec_gt)
np.save('mse.npy', mse)
np.save('pcc.npy', pcc)

STOI

In [20]:
# Per-sample STOI between GT and predicted audio windows, cached in stois.npy.
# Only the cache file's existence is checked, so delete it to recompute after
# the .wav files change.
if not os.path.exists('stois.npy'):
    n_segments = 180
    stois = np.zeros([n_segments])
    for k in range(n_segments):
        window = slice(k * interval, (k + 1) * interval)
        stois[k] = stoi(wave_gt[window], wave_pred[window], 16000, extended=False)
    np.save('stois.npy', stois)
else:
    stois = np.load('stois.npy')
In [21]:
# Histogram of per-sample STOI, titled "mean(std)" rounded to 3 decimals.
stoi_mean = np.round(np.mean(stois), 3)
stoi_std = np.round(np.std(stois), 3)
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(stois, bins=25, color='b')
ax.set_title('STOI: %g(%g)' % (stoi_mean, stoi_std))
Out[21]:
Text(0.5, 1.0, 'STOI: 0.524(0.13)')

MCD

In [22]:
# Histogram of precomputed per-sample MCD, titled "mean(std) dB" rounded to
# 3 decimals.
distances = np.load('mcd_distances.npy')
mcd_mean = np.round(np.mean(distances), 3)
mcd_std = np.round(np.std(distances), 3)
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(distances, bins=25, color='b')
ax.set_title('MCD: %g(%g) dB' % (mcd_mean, mcd_std))
Out[22]:
Text(0.5, 1.0, 'MCD: 2.809(0.674) dB')